{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1>Federated Datasets</h1>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p style=\"text-align:justify\">This example demonstrates how to develop custom datasets that work with federated data loaders. Regular torch datasets do not always work directly with federated data loaders, so the example highlights what has to be done differently. Alternatively, you can use Syft's BaseDataset feature, which simplifies creating federated datasets. To demonstrate both approaches, we load the SVHN (Street View House Numbers) dataset and convert it to a federated dataset.</p>\n",
    "\n",
    "SVHN Dataset: <a href=\"http://ufldl.stanford.edu/housenumbers/\">Link</a>\n",
    "\n",
    "Authored By:\n",
    "\n",
    "Hrishikesh Kamath - GitHub: @<a href=\"http://github.com/kamathhrishi\">kamathhrishi</a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "import syft as sy\n",
    "import torch\n",
    "import urllib.request  # urlretrieve lives in the urllib.request submodule\n",
    "from pathlib import Path\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the workers to which the data will be distributed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hook = sy.TorchHook(torch)  # <-- NEW: hook PyTorch, i.e. add extra functionality to support Federated Learning\n",
    "bob = sy.VirtualWorker(hook, id=\"bob\")  # <-- NEW: define remote worker bob\n",
    "alice = sy.VirtualWorker(hook, id=\"alice\")  # <-- NEW: and alice"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These are the hyperparameters that would be used while training a model. For this tutorial we only need the batch size. In general, keeping hyperparameters in a separate class or a dictionary is good practice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Arguments():\n",
    "    def __init__(self):\n",
    "        self.batch_size = 1\n",
    "        self.test_batch_size = 1000\n",
    "        self.seed = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = Arguments()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Load Data</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check whether both dataset files are already present\n",
    "def dataset_exists():\n",
    "    return (os.path.isfile('./data/train_32x32.mat') and\n",
    "            os.path.isfile('./data/test_32x32.mat'))\n",
    "\n",
    "\n",
    "# If the dataset does not exist, download it directly from the URL where it is hosted\n",
    "if not dataset_exists():\n",
    "    Path('./data/').mkdir(parents=True, exist_ok=True)\n",
    "    print('Downloading the dataset with urllib to the data directory...')\n",
    "    url1 = 'http://ufldl.stanford.edu/housenumbers/train_32x32.mat'\n",
    "    urllib.request.urlretrieve(url1, './data/train_32x32.mat')\n",
    "    url2 = 'http://ufldl.stanford.edu/housenumbers/test_32x32.mat'\n",
    "    urllib.request.urlretrieve(url2, './data/test_32x32.mat')\n",
    "    print(\"The dataset was successfully downloaded\")\n",
    "else:\n",
    "    print(\"Not downloading the dataset because it was already downloaded\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The data required for this tutorial comes from the following sources:\n",
    "\n",
    "<ul>\n",
    "\n",
    "<li><a href=\"http://ufldl.stanford.edu/housenumbers/train_32x32.mat\">Train Data</a></li>\n",
    "<li><a href=\"http://ufldl.stanford.edu/housenumbers/test_32x32.mat\">Test Data</a></li>\n",
    "\n",
    "</ul>\n",
    "\n",
    "The dataset is in MATLAB format.\n",
    "\n",
    "This section loads and pre-processes the SVHN dataset and has little to do with creating a federated dataset, so you can skip it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.io import loadmat\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def load_data(path):\n",
    "    \"\"\" Helper function for loading a MAT-File\"\"\"\n",
    "    data = loadmat(path)\n",
    "    return data['X'], data['y']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data, train_labels = load_data(\"data/train_32x32.mat\")\n",
    "test_data, test_labels = load_data(\"data/test_32x32.mat\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(train_data.shape)\n",
    "print(test_data.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that the NumPy arrays above are laid out as (height, width, channels, samples), which is not the usual layout for a batch of images. We therefore transpose them to (samples, height, width, channels)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transpose the image arrays\n",
    "X_train, y_train = train_data.transpose((3,0,1,2)), train_labels[:,0]\n",
    "X_test, y_test = test_data.transpose((3,0,1,2)), test_labels[:,0]"
   ]
  },
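  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick illustration of what the transpose does, the sketch below applies the same axis permutation to a small dummy array. The array `dummy` and its size of 5 samples are made up for this example and are not part of the SVHN data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical stand-in for the raw data: (height, width, channels, samples)\n",
    "dummy = np.zeros((32, 32, 3, 5), dtype=np.uint8)\n",
    "\n",
    "# Move the sample axis to the front, as done for X_train and X_test above\n",
    "print(dummy.transpose((3, 0, 1, 2)).shape)  # (5, 32, 32, 3)"
   ]
  },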
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Visualize Data</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig=plt.figure(figsize=(8, 8))\n",
    "columns = 4\n",
    "rows = 5\n",
    "for i in range(1,columns*rows+1):\n",
    "    fig.add_subplot(rows, columns, i)\n",
    "    plt.imshow(X_train[i])\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Torch Dataset</h2>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The SVHN data is stored as NumPy arrays, but the data could equally be a Python list or any other data type that can be converted to torch tensors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SVHNDataset(Dataset):\n",
    "\n",
    "    def __init__(self,images,labels,transform=None):\n",
    "        \n",
    "        \"\"\"Args:\n",
    "             \n",
    "             images (Numpy Array): Image Data\n",
    "             labels (Numpy Array): Labels corresponding to each image\n",
    "             transform (Optional): If any torch transform has to be performed on the dataset\n",
    "             \n",
    "        \"\"\"\n",
    "        \n",
    "        # Attributes self.data and self.targets must be initialized.\n",
    "        \n",
    "        #<--Data must be initialized as self.data,self.train_data or self.test_data\n",
    "        self.data=images\n",
    "        #<--Targets must be initialized as self.targets,self.test_labels or self.train_labels\n",
    "        self.targets=labels\n",
    "        \n",
    "        #<--The data and targets must be converted to torch tensors before they are returned by the __getitem__ method\n",
    "        self.to_torchtensor()\n",
    "        \n",
    "        #<--If any transforms have to be performed on the dataset\n",
    "        self.transform = transform\n",
    "        \n",
    "        \n",
    "    def to_torchtensor(self):\n",
    "        \n",
    "        \"Transform Numpy Arrays to Torch tensors.\"\n",
    "        \n",
    "        self.data=torch.from_numpy(self.data)\n",
    "        # Store the converted labels back in self.targets, which __getitem__ reads\n",
    "        self.targets=torch.from_numpy(self.targets)\n",
    "    \n",
    "        \n",
    "    def __len__(self):\n",
    "        \n",
    "        \"\"\"Required Method\n",
    "            \n",
    "           Returns:\n",
    "        \n",
    "                Length [int]: Length of Dataset/batches\n",
    "        \n",
    "        \"\"\"\n",
    "        \n",
    "        return len(self.data)\n",
    "    \n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \n",
    "        \"\"\"Required Method\n",
    "        \n",
    "           The output of this method must be torch tensors, since hooked torch tensors \n",
    "           provide the send() method that is used to send data to workers.\n",
    "        \n",
    "           Args:\n",
    "                 \n",
    "                 idx [integer]: The index of required batch/example\n",
    "                 \n",
    "           Returns:\n",
    "                 \n",
    "                 Data [Torch Tensor]:     The training examples\n",
    "                 Target [ Torch Tensor]:  Corresponding labels of training examples \n",
    "        \n",
    "        \"\"\"\n",
    "        \n",
    "        sample=self.data[idx]\n",
    "        target=self.targets[idx]\n",
    "                \n",
    "        if self.transform:\n",
    "            sample = self.transform(sample)\n",
    "\n",
    "        return sample,target"
   ]
  },
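  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before federating, it can help to check that the class behaves like a regular torch dataset. The cell below is a minimal sketch that builds an instance on a small slice of the training data. The name `local_check` and the slice size of 8 are arbitrary choices for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical local instance on the first 8 examples; nothing is sent to a worker here\n",
    "local_check = SVHNDataset(X_train[:8], y_train[:8])\n",
    "\n",
    "# __len__ reports the number of examples in the slice\n",
    "print(len(local_check))\n",
    "\n",
    "# __getitem__ returns one (image, label) pair\n",
    "sample, target = local_check[0]\n",
    "print(sample.shape, target)"
   ]
  },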
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p style=\"text-align:justify\">Call the federate method on the torch dataset instance with the workers as arguments, then pass the result to the federated data loader. This distributes the dataset across the given workers and returns the corresponding pointer tensors. The federated train loader can then be iterated like a regular torch data loader, yielding pointer tensors to the examples and their labels.</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "federated_SVHN=SVHNDataset(X_train,y_train).federate((bob, alice))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "federated_train_loader = sy.FederatedDataLoader( # <-- this is now a FederatedDataLoader \n",
    "                         federated_SVHN,batch_size=args.batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Syft Base Datasets</h2>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Syft's BaseDataset is a simplified dataset class in the Syft library that lets you create a dataset by simply providing the training data and the corresponding labels. It can also be used with federated data loaders. Ensure the inputs to BaseDataset are torch tensors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base=sy.BaseDataset(torch.from_numpy(X_train),torch.from_numpy(y_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_federated=base.federate((bob, alice))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "federated_train_loader = sy.FederatedDataLoader( # <-- this is now a FederatedDataLoader \n",
    "                         base_federated,batch_size=args.batch_size)"
   ]
  },
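  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loader can then be iterated like a regular torch data loader. The following is a minimal sketch, assuming the PySyft 0.2.x API used in this notebook, in which each batch is a pair of pointer tensors whose `location` attribute is the worker holding the data. Only the first few batches are inspected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for batch_idx, (data, target) in enumerate(federated_train_loader):\n",
    "    # data and target are pointers; the tensors themselves stay on the worker\n",
    "    print(batch_idx, data.location.id)\n",
    "    if batch_idx >= 2:\n",
    "        break"
   ]
  },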
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Well Done!\n",
    "\n",
    "And voilà! We now know how to create a custom dataset that works with federated data loaders. \n",
    "\n",
    "## Shortcomings of this Example\n",
    "\n",
    "Federated datasets were developed to let users perform federated learning easily with federated data loaders. If you can think of a feature that would improve your experience, feel free to create a GitHub issue."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Congratulations!!! - Time to Join the Community!\n",
    "\n",
    "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!\n",
    "\n",
    "\n",
    "### Star PySyft on GitHub\n",
    "\n",
    "The easiest way to help our community is just by starring the repositories! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### Pick our tutorials on GitHub!\n",
    "\n",
    "We made really nice tutorials to get a better understanding of what Federated and Privacy-Preserving Learning should look like and how we are building the bricks for this to happen.\n",
    "\n",
    "- [Check out the PySyft tutorials](https://github.com/OpenMined/PySyft/tree/master/examples/tutorials)\n",
    "\n",
    "\n",
    "### Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community! \n",
    "\n",
    "- [Join slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### Join a Code Project!\n",
    "\n",
    "The best way to contribute to our community is to become a code contributor! If you want to start \"one off\" mini-projects, you can go to the PySyft GitHub Issues page and search for issues marked `Good First Issue`.\n",
    "\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### Donate\n",
    "\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "- [Donate through OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
